{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 中间表示中的类型节点\n",
    "\n",
    "参考：`tvm/tests/python/ir/test_ir_type.py`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "ROOT = Path(\".\").resolve().parents[2]\n",
    "sys.path.extend([f\"{ROOT}/tests\", f\"{ROOT}/src\"])\n",
    "# # from tools.tag_span import _create_span, _set_span, _verify_structural_equal_with_span\n",
    "from tools.torch_utils import verify_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "\n",
    "def check_json_roundtrip(node):\n",
    "    json_str = tvm.ir.save_json(node)\n",
    "    back = tvm.ir.load_json(json_str)\n",
    "    assert tvm.ir.structural_equal(back, node, map_free_vars=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def test_prim_type():\n",
    "    x = tvm.ir.PrimType(\"int32\")\n",
    "    assert isinstance(x, tvm.ir.PrimType)\n",
    "    assert x.dtype == \"int32\"\n",
    "\n",
    "\n",
    "def test_tensor_type_bad_constructor():\n",
    "    try:\n",
    "        x = tvm.ir.TensorType(\"xx\", \"xx\")\n",
    "    except tvm.error.TVMError:\n",
    "        pass\n",
    "\n",
    "\n",
    "def test_tensor_type():\n",
    "    shape = tvm.runtime.convert([1, 2, 3])\n",
    "    dtype = \"float32\"\n",
    "    tt = tvm.ir.TensorType(shape, dtype)\n",
    "    assert tt.dtype == dtype\n",
    "    assert tt.shape == shape\n",
    "    assert tt.span == None\n",
    "    str(tt)\n",
    "    check_json_roundtrip(tt)\n",
    "\n",
    "\n",
    "def test_type_param():\n",
    "    tp = tvm.ir.TypeVar(\"name\", tvm.ir.TypeKind.Type)\n",
    "    assert tp.kind == tvm.ir.TypeKind.Type\n",
    "    # assert tp.span  # TODO allow us to set span\n",
    "    str(tp)\n",
    "    check_json_roundtrip(tp)\n",
    "\n",
    "\n",
    "def test_func_type():\n",
    "    type_params = tvm.runtime.convert([])\n",
    "    type_constraints = tvm.runtime.convert([])  # TODO: fill me in\n",
    "    arg_types = tvm.runtime.convert([])\n",
    "    ret_type = tvm.ir.TensorType((1, 2, 3), \"float32\")\n",
    "    tf = tvm.ir.FuncType(arg_types, ret_type, type_params, type_constraints)\n",
    "    assert tf.type_params == type_params\n",
    "    assert tf.type_constraints == type_constraints\n",
    "    assert tf.arg_types == arg_types\n",
    "    assert tf.ret_type == ret_type\n",
    "    assert tf.span == None\n",
    "    # TODO make sure we can set span\n",
    "    str(tf)\n",
    "    check_json_roundtrip(tf)\n",
    "\n",
    "\n",
    "def test_tuple_type():\n",
    "    tp = tvm.ir.TypeVar(\"tp\", tvm.ir.TypeKind.Type)\n",
    "    tf = tvm.ir.FuncType([], tvm.ir.TupleType([]), [], [])\n",
    "    tt = tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), \"float32\")\n",
    "    fields = tvm.runtime.convert([tp, tf, tt])\n",
    "\n",
    "    tup_ty = tvm.ir.TupleType(fields)\n",
    "    assert tup_ty.fields == fields\n",
    "    str(tup_ty)\n",
    "    check_json_roundtrip(tup_ty)\n",
    "\n",
    "\n",
    "def test_type_relation():\n",
    "    tp = tvm.ir.TypeVar(\"tp\", tvm.ir.TypeKind.Type)\n",
    "    tf = tvm.ir.FuncType([], None, [], [])\n",
    "    tt = tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), \"float32\")\n",
    "    args = tvm.runtime.convert([tp, tf, tt])\n",
    "\n",
    "    num_inputs = 2\n",
    "    func = tvm.ir.EnvFunc.get(\"tvm.relay.type_relation.Broadcast\")\n",
    "    attrs = tvm.ir.make_node(\"attrs.TestAttrs\", name=\"attr\", padding=(3, 4))\n",
    "\n",
    "    tr = tvm.ir.TypeRelation(func, args, num_inputs, attrs)\n",
    "    assert tr.args == args\n",
    "    assert tr.num_inputs == num_inputs\n",
    "    str(tr)\n",
    "    check_json_roundtrip(tr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312x",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
